import torch

def build_example():
	"""Return a small hand-written ``(logits, y)`` pair for argmax experiments.

	Returns:
		logits: float32 tensor of shape (3, 3, 3).
		y: int64 (long) tensor of shape (3, 3). Presumably these are target
			indices into the last axis of ``logits``. NOTE(review): confirm
			intent. As written, ``logits.argmax(-1)`` equals ``y`` element-wise.
	"""
	# torch.tensor(..., dtype=...) is the recommended replacement for the
	# legacy torch.FloatTensor / torch.LongTensor data constructors.
	logits = torch.tensor([
		[[1, 3, 2], [1, 2, 3], [3, 1, 2]],
		[[3, 1, 2], [1, 3, 2], [2, 1, 3]],
		[[1, 3, 4], [1, 4, 3], [3, 1, 2]],
	], dtype=torch.float32)
	y = torch.tensor([
		[1, 2, 0],
		[0, 1, 2],
		[2, 1, 0],
	], dtype=torch.long)
	return logits, y


if __name__ == '__main__':
	############### 3-D #################
	logits, y = build_example()
	# Index of the largest value along the last axis -> int64 tensor of shape (3, 3).
	print(logits.argmax(-1))
